﻿using System;
using System.Collections.Generic;
using System.Data.Entity;
using System.Data.Entity.Core.Objects;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Text;

namespace WooCoo.Data.EF
{
    /// <summary>
    /// Automatically controls reuse of an Entity Framework <see cref="System.Data.Entity.DbContext" />
    /// within the current application context. Because every consumer in the same
    /// context shares one DbContext (and therefore its transaction on save), the
    /// manager can be used to obtain an IUnitOfWork-like effect.
    /// </summary>
    /// <typeparam name="C">
    /// Type of database 
    /// DbContext object to use.
    /// </typeparam>
    /// <remarks>
    /// This type stores the manager object in <c>ApplicationContext.LocalContext</c>
    /// and uses reference counting through
    /// <see cref="IDisposable" /> to keep the DbContext object
    /// open for reuse by child objects, and to automatically
    /// dispose the object when the last consumer
    /// has called Dispose.
    /// </remarks>
    public class DbContextManager<C> : IDisposable where C : DbContext
    {
        // Guards the reference count and the LocalContext registration.
        // Note: this is a static field of a generic type, so each closed type
        // DbContextManager<C> has its own lock object.
        private static readonly object _lock = new object();

        // The managed context; created once in the constructor and disposed
        // when the reference count drops to zero.
        private readonly C _context;

        // Key under which this manager is registered in ApplicationContext.LocalContext.
        private readonly string _contextLabel;

        /// <summary>
        /// Gets the DbContextManager object for the 
        /// specified database.
        /// </summary>
        /// <param name="database">
        /// The database name or connection string.
        /// </param>
        /// <param name="label">Label for this context.</param>
        /// <param name="model">Database Compiled model.</param>
        /// <returns>ContextManager object for the name.</returns>
        public static DbContextManager<C> GetManager(string database, string label, DbCompiledModel model)
        {
            var contextLabel = GetContextName(database, label);
            return GetOrCreateManager(
                contextLabel,
                () => new DbContextManager<C>(database, model, null, contextLabel));
        }

        /// <summary>
        /// Creates the manager and the DbContext instance it wraps.
        /// </summary>
        /// <param name="database">Database name or connection string; may be null or empty.</param>
        /// <param name="model">Compiled model; when not null it takes precedence over <paramref name="context"/>.</param>
        /// <param name="context">Existing ObjectContext to wrap; used only when <paramref name="model"/> is null.</param>
        /// <param name="contextLabel">Key under which the manager is registered in the local context.</param>
        /// <remarks>
        /// The DbContext is created through <see cref="Activator.CreateInstance(Type, object[])"/>,
        /// so <typeparamref name="C"/> must expose a constructor matching the argument
        /// list selected below.
        /// </remarks>
        private DbContextManager(string database, DbCompiledModel model, ObjectContext context, string contextLabel)
        {
            _contextLabel = contextLabel;

            if (model != null)
            {
                if (!string.IsNullOrEmpty(database))
                {
                    // C(string nameOrConnectionString, DbCompiledModel model)
                    _context = (C)(Activator.CreateInstance(typeof(C), database, model));
                }
                else
                {
                    // C(DbCompiledModel model)
                    _context = (C)(Activator.CreateInstance(typeof(C), model));
                }
            }
            else if (context != null)
            {
                // C(ObjectContext objectContext, bool dbContextOwnsObjectContext)
                _context = (C)(Activator.CreateInstance(typeof(C), context, true));
            }
            else if (string.IsNullOrEmpty(database))
            {
                // C()
                _context = (C)(Activator.CreateInstance(typeof(C)));
            }
            else
            {
                // C(string nameOrConnectionString)
                _context = (C)(Activator.CreateInstance(typeof(C), database));
            }
        }

        /// <summary>
        /// Gets the DbContextManager object using an empty database name
        /// and the "default" label.
        /// </summary>
        /// <returns>The shared manager, with its reference count incremented.</returns>
        public static DbContextManager<C> GetManager()
        {
            return GetManager(string.Empty);
        }

        /// <summary>
        /// Gets the DbContextManager object for the 
        /// specified database, using the "default" label.
        /// </summary>
        /// <param name="database">Database name as shown in the config file.</param>
        /// <returns>The shared manager, with its reference count incremented.</returns>
        public static DbContextManager<C> GetManager(string database)
        {
            return GetManager(database, "default", null);
        }

        /// <summary>
        /// Gets the DbContextManager object for the 
        /// specified database.
        /// </summary>
        /// <param name="database">Database name as shown in the config file.</param>
        /// <param name="label">Label for this context.</param>
        /// <returns>The shared manager, with its reference count incremented.</returns>
        public static DbContextManager<C> GetManager(string database, string label)
        {
            return GetManager(database, label, null);
        }

        /// <summary>
        /// Gets the DbContextManager object for the 
        /// specified database, using the "default" label.
        /// </summary>
        /// <param name="database">Database name as shown in the config file.</param>
        /// <param name="model">Database compiled model.</param>
        /// <returns>The shared manager, with its reference count incremented.</returns>
        public static DbContextManager<C> GetManager(string database, DbCompiledModel model)
        {
            return GetManager(database, "default", model);
        }

        /// <summary>
        /// Gets the DbContextManager object wrapping the
        /// specified ObjectContext, using the "default" label.
        /// </summary>
        /// <param name="context">
        /// The ObjectContext to wrap in DbContext
        /// </param>
        /// <returns>The shared manager, with its reference count incremented.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
        public static DbContextManager<C> GetManager(ObjectContext context)
        {
            return GetManager(context, "default");
        }

        /// <summary>
        /// Gets the DbContextManager object wrapping the
        /// specified ObjectContext.
        /// </summary>
        /// <param name="context">
        /// The ObjectContext to wrap in DbContext
        /// </param>
        /// <param name="label">Label for this context.</param>
        /// <returns>The shared manager, with its reference count incremented.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
        public static DbContextManager<C> GetManager(ObjectContext context, string label)
        {
            // Fail fast with a meaningful exception instead of a
            // NullReferenceException on context.DefaultContainerName.
            if (context == null)
            {
                throw new ArgumentNullException("context");
            }

            var contextLabel = GetContextName(context.DefaultContainerName, label);
            return GetOrCreateManager(
                contextLabel,
                () => new DbContextManager<C>(null, null, context, contextLabel));
        }

        /// <summary>
        /// Returns the manager registered under <paramref name="contextLabel"/>,
        /// creating and registering it through <paramref name="factory"/> when none
        /// exists yet, and increments its reference count.
        /// </summary>
        /// <param name="contextLabel">Key in <c>ApplicationContext.LocalContext</c>.</param>
        /// <param name="factory">Creates a new manager; invoked only when no manager is registered.</param>
        /// <returns>The registered manager.</returns>
        private static DbContextManager<C> GetOrCreateManager(string contextLabel, Func<DbContextManager<C>> factory)
        {
            lock (_lock)
            {
                DbContextManager<C> mgr;
                if (ApplicationContext.LocalContext.Contains(contextLabel))
                {
                    mgr = (DbContextManager<C>)(ApplicationContext.LocalContext[contextLabel]);
                }
                else
                {
                    // If the factory throws, nothing is registered.
                    mgr = factory();
                    ApplicationContext.LocalContext[contextLabel] = mgr;
                }
                mgr.AddRef();
                return mgr;
            }
        }

        /// <summary>
        /// Builds the key used to store a manager in the local context.
        /// </summary>
        /// <remarks>
        /// The context type is part of the key so that managers for different
        /// DbContext types sharing the same label and database name do not
        /// collide (which would otherwise fail with an invalid cast on lookup).
        /// </remarks>
        private static string GetContextName(string database, string label)
        {
            return "__dbctx:" + typeof(C).FullName + ":" + label + "-" + database;
        }

        /// <summary>
        /// Gets the DbContext object.
        /// </summary>
        public C DbContext
        {
            get
            {
                return _context;
            }
        }

        #region  Reference counting

        private int _refCount;

        /// <summary>
        /// Gets the current reference count for this
        /// object.
        /// </summary>
        public int RefCount
        {
            get { return _refCount; }
        }

        // Callers must hold _lock (see GetOrCreateManager).
        private void AddRef()
        {
            _refCount += 1;
        }

        /// <summary>
        /// Decrements the reference count; when it reaches zero the DbContext is
        /// disposed and the manager is removed from the local context.
        /// </summary>
        private void DeRef()
        {
            lock (_lock)
            {
                // Extra Dispose calls after the final release are ignored so the
                // count never goes negative.
                if (_refCount <= 0)
                {
                    return;
                }

                _refCount -= 1;
                if (_refCount == 0)
                {
                    try
                    {
                        _context.Dispose();
                    }
                    finally
                    {
                        // Always unregister, even if disposing the context throws;
                        // otherwise later GetManager calls would receive a manager
                        // whose context has already been torn down.
                        ApplicationContext.LocalContext.Remove(_contextLabel);
                    }
                }
            }
        }

        #endregion

        /// <summary>
        /// Commits pending changes. A reference count of 1 means only the
        /// outermost consumer is left, so the changes can be saved; the actual
        /// save therefore happens in the outermost <c>using</c> scope.
        /// </summary>
        /// <returns>
        /// The value returned by <c>SaveChanges()</c> when the reference count is 1;
        /// otherwise 0 and nothing is saved.
        /// </returns>
        public int Commit()
        {
            lock (_lock)
            {
                if (_refCount == 1)
                {
                    return _context.SaveChanges();
                }
                return 0;
            }
        }

        #region  IDisposable

        /// <summary>
        /// Dispose object, dereferencing or
        /// disposing the context it is
        /// managing.
        /// </summary>
        public void Dispose()
        {
            DeRef();
        }

        #endregion

    }
}
